{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "265ca5a2",
   "metadata": {},
   "source": [
    "# Torch Hub Detection Inference Tutorial\n",
    "\n",
    "In this tutorial you'll learn:\n",
    "- how to load a pretrained detection model using Torch Hub\n",
    "- how to run inference to detect actions in a demo video"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e684f3e4",
   "metadata": {},
   "source": [
    "## NOTE: \n",
    "At the moment, this tutorial only works when run from a local clone, inside the directory `pytorchvideo/tutorials/video_detection_example`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1084c2f",
   "metadata": {},
   "source": [
    "### Install and Import modules\n",
    "If `torch`, `torchvision`, `cv2` and `pytorchvideo` are not installed, run the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "130e7aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys  # needed below for the platform check, whether or not torch is already installed\n",
    "\n",
    "try:\n",
    "    import torch\n",
    "except ModuleNotFoundError:\n",
    "    !pip install torch torchvision\n",
    "    import torch\n",
    "\n",
    "try:\n",
    "    import cv2\n",
    "except ModuleNotFoundError:\n",
    "    !pip install opencv-python\n",
    "    \n",
    "if torch.__version__=='1.6.0+cu101' and sys.platform.startswith('linux'):\n",
    "    !pip install pytorchvideo\n",
    "else:\n",
    "    need_pytorchvideo=False\n",
    "    try:\n",
    "        # Running notebook locally\n",
    "        import pytorchvideo\n",
    "    except ModuleNotFoundError:\n",
    "        need_pytorchvideo=True\n",
    "    if need_pytorchvideo:\n",
    "        # Install from GitHub\n",
    "        !pip install \"git+https://github.com/facebookresearch/pytorchvideo.git\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "74d4dee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import numpy as np\n",
    "\n",
    "import cv2\n",
    "import torch\n",
    "\n",
    "import detectron2\n",
    "from detectron2.config import get_cfg\n",
    "from detectron2 import model_zoo\n",
    "from detectron2.engine import DefaultPredictor\n",
    "\n",
    "import pytorchvideo\n",
    "from pytorchvideo.transforms.functional import (\n",
    "    uniform_temporal_subsample,\n",
    "    short_side_scale_with_boxes,\n",
    "    clip_boxes_to_image,\n",
    ")\n",
    "from torchvision.transforms._functional_video import normalize\n",
    "from pytorchvideo.data.ava import AvaLabeledVideoFramePaths\n",
    "from pytorchvideo.models.hub import slow_r50_detection # Another option is slowfast_r50_detection\n",
    "\n",
    "from visualization import VideoVisualizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6c8faad",
   "metadata": {},
   "source": [
    "## Load Model using Torch Hub API\n",
    "PyTorchVideo provides several pretrained models through Torch Hub. Available models are described in the [model zoo documentation](https://github.com/facebookresearch/pytorchvideo/blob/main/docs/source/model_zoo.md).\n",
    "\n",
    "Here we select the `slow_r50_detection` model, which was trained using a 4x16 setting on the Kinetics 400 dataset and \n",
    "fine-tuned on the AVA V2.2 actions dataset.\n",
    "\n",
    "NOTE: to run on GPU in Google Colab, select Runtime -> Change runtime type -> Hardware Accelerator -> GPU in the menu bar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6bb9a374",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' # or 'cpu'\n",
    "video_model = slow_r50_detection(True) # Another option is slowfast_r50_detection\n",
    "video_model = video_model.eval().to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f21c0ea",
   "metadata": {},
   "source": [
    "## Load an off-the-shelf Detectron2 object detector\n",
    "\n",
    "We use the object detector to detect bounding boxes for the people. \n",
    "These bounding boxes later feed into our video action detection model.\n",
    "For more details, refer to Detectron2's object detection tutorials.\n",
    "\n",
    "To install Detectron2, follow the [installation instructions](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7a5d5f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = get_cfg()\n",
    "cfg.merge_from_file(model_zoo.get_config_file(\"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml\"))\n",
    "cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.55  # set threshold for this model\n",
    "cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(\"COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml\")\n",
    "predictor = DefaultPredictor(cfg)\n",
    "\n",
    "# This method takes in an image and generates the bounding boxes for people in the image.\n",
    "def get_person_bboxes(inp_img, predictor):\n",
    "    predictions = predictor(inp_img.cpu().detach().numpy())['instances'].to('cpu')\n",
    "    boxes = predictions.pred_boxes if predictions.has(\"pred_boxes\") else None\n",
    "    scores = predictions.scores if predictions.has(\"scores\") else None\n",
    "    classes = np.array(predictions.pred_classes.tolist() if predictions.has(\"pred_classes\") else None)\n",
    "    predicted_boxes = boxes[np.logical_and(classes==0, scores>0.75 )].tensor.cpu() # only person\n",
    "    return predicted_boxes"
   ]
  },
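  {
   "cell_type": "markdown",
   "id": "gr-sketch-person-filter",
   "metadata": {},
   "source": [
    "The filtering step inside `get_person_bboxes` can be illustrated without Detectron2. The following is a minimal sketch on plain Python lists, not the tutorial's implementation: `filter_person_boxes` and the sample detections are hypothetical, while the person class id `0` and the `0.75` score threshold are taken from the cell above.\n",
    "\n",
    "```python\n",
    "PERSON_CLASS_ID = 0  # class id treated as 'person' in the cell above\n",
    "\n",
    "\n",
    "def filter_person_boxes(boxes, classes, scores, min_score=0.75):\n",
    "    # Keep boxes whose class is 'person' and whose score exceeds min_score.\n",
    "    return [\n",
    "        box\n",
    "        for box, cls, score in zip(boxes, classes, scores)\n",
    "        if cls == PERSON_CLASS_ID and score > min_score\n",
    "    ]\n",
    "\n",
    "\n",
    "# Hypothetical detections in [x1, y1, x2, y2] format.\n",
    "boxes = [[10, 20, 110, 220], [30, 40, 90, 200], [0, 0, 50, 50]]\n",
    "classes = [0, 0, 17]         # two people and one non-person object\n",
    "scores = [0.92, 0.60, 0.99]  # the second person is below the threshold\n",
    "\n",
    "kept = filter_person_boxes(boxes, classes, scores)\n",
    "print(kept)\n",
    "```\n",
    "\n",
    "Only the first detection passes both tests, so `kept` holds a single box. The high-scoring third detection is dropped because it is not a person."
   ]
  },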
  {
   "cell_type": "markdown",
   "id": "d8babcba",
   "metadata": {},
   "source": [
    "## Define the transformations for the input required by the model\n",
    "Before passing the video and bounding boxes into the model, we need to apply some input transforms and sample the clip at the correct frame rate.\n",
    "\n",
    "Below we define a function that pre-processes the clip and bounding boxes. It generates inputs for either the Slow (ResNet) or the SlowFast model, depending on the value of the parameter `slow_fast_alpha`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9cb1ec3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ava_inference_transform(\n",
    "    clip, \n",
    "    boxes,\n",
    "    num_frames = 4, #if using slowfast_r50_detection, change this to 32\n",
    "    crop_size = 256, \n",
    "    data_mean = [0.45, 0.45, 0.45], \n",
    "    data_std = [0.225, 0.225, 0.225],\n",
    "    slow_fast_alpha = None, #if using slowfast_r50_detection, change this to 4\n",
    "):\n",
    "\n",
    "    boxes = np.array(boxes)\n",
    "    ori_boxes = boxes.copy()\n",
    "\n",
    "    # Sample num_frames frames uniformly in time, then scale pixels from [0, 255] to [0, 1].\n",
    "    clip = uniform_temporal_subsample(clip, num_frames)\n",
    "    clip = clip.float()\n",
    "    clip = clip / 255.0\n",
    "\n",
    "    height, width = clip.shape[2], clip.shape[3]\n",
    "    # The format of boxes is [x1, y1, x2, y2]. The input boxes are in the\n",
    "    # range of [0, width] for x and [0,height] for y\n",
    "    boxes = clip_boxes_to_image(boxes, height, width)\n",
    "\n",
    "    # Resize the short side to crop_size. Non-local and STRG use 256.\n",
    "    clip, boxes = short_side_scale_with_boxes(\n",
    "        clip,\n",
    "        size=crop_size,\n",
    "        boxes=boxes,\n",
    "    )\n",
    "    \n",
    "    # Normalize images by mean and std.\n",
    "    clip = normalize(\n",
    "        clip,\n",
    "        np.array(data_mean, dtype=np.float32),\n",
    "        np.array(data_std, dtype=np.float32),\n",
    "    )\n",
    "    \n",
    "    boxes = clip_boxes_to_image(\n",
    "        boxes, clip.shape[2],  clip.shape[3]\n",
    "    )\n",
    "    \n",
    "    # In the SlowFast case, generate both pathways.\n",
    "    if slow_fast_alpha is not None:\n",
    "        fast_pathway = clip\n",
    "        # Perform temporal sampling from the fast pathway.\n",
    "        slow_pathway = torch.index_select(\n",
    "            clip,\n",
    "            1,\n",
    "            torch.linspace(\n",
    "                0, clip.shape[1] - 1, clip.shape[1] // slow_fast_alpha\n",
    "            ).long(),\n",
    "        )\n",
    "        clip = [slow_pathway, fast_pathway]\n",
    "    \n",
    "    return clip, torch.from_numpy(boxes), ori_boxes"
   ]
  },
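  {
   "cell_type": "markdown",
   "id": "gr-sketch-transform-math",
   "metadata": {},
   "source": [
    "Two pieces of arithmetic in `ava_inference_transform` are easy to check by hand: which frame indices are kept, and how the boxes follow the resize. The sketch below reproduces both in plain Python under stated assumptions. `linspace_indices` and `scale_boxes_short_side` are hypothetical helpers that mimic evenly spaced sampling with truncation and short-side scaling; they are not the PyTorchVideo functions, whose rounding may differ, and the frame counts and box values are made up.\n",
    "\n",
    "```python\n",
    "def linspace_indices(num_frames, num_samples):\n",
    "    # Evenly spaced positions from 0 to num_frames - 1, floored to integers.\n",
    "    if num_samples == 1:\n",
    "        return [0]\n",
    "    last = num_frames - 1\n",
    "    return [i * last // (num_samples - 1) for i in range(num_samples)]\n",
    "\n",
    "\n",
    "def scale_boxes_short_side(boxes, height, width, size=256):\n",
    "    # Scale factor that maps the shorter image side to `size`.\n",
    "    scale = size / min(height, width)\n",
    "    return [[coord * scale for coord in box] for box in boxes]\n",
    "\n",
    "\n",
    "# A hypothetical 30-frame clip reduced to the 4 frames used by the Slow model.\n",
    "print(linspace_indices(30, 4))  # [0, 9, 19, 29]\n",
    "\n",
    "# 32 fast-pathway frames reduced by alpha = 4 to get the slow pathway.\n",
    "print(linspace_indices(32, 32 // 4))  # [0, 4, 8, 13, 17, 22, 26, 31]\n",
    "\n",
    "# A 512 x 1024 frame: the short side (512) maps to 256, so coordinates halve.\n",
    "print(scale_boxes_short_side([[100, 50, 300, 512]], 512, 1024))\n",
    "# [[50.0, 25.0, 150.0, 256.0]]\n",
    "```\n",
    "\n",
    "The first and last frames are always kept, and the boxes are scaled by the same factor as the image so they stay aligned with the people they enclose."
   ]
  },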
  {
   "cell_type": "markdown",
   "id": "6f26315c",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Download the id-to-label mapping for the AVA V2.2 dataset on which the Torch Hub models were fine-tuned. \n",
    "This will be used to get the category label names from the predicted class ids.\n",
    "\n",
    "Create a visualizer to plot the results (labels + bounding boxes)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "6132a777",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://dl.fbaipublicfiles.com/pytorchvideo/data/class_names/ava_action_list.pbtxt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "39454172",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an id to label name mapping\n",
    "label_map, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map('ava_action_list.pbtxt')\n",
    "# Create a video visualizer that can plot bounding boxes and visualize actions on bboxes.\n",
    "video_visualizer = VideoVisualizer(81, label_map, top_k=3, mode=\"thres\",thres=0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7086f4a0",
   "metadata": {},
   "source": [
    "## Load an example video\n",
    "We use an open-source video from Wikimedia."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "f27c302c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://dl.fbaipublicfiles.com/pytorchvideo/projects/theatre.webm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b8bcc454",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completed loading encoded video.\n"
     ]
    }
   ],
   "source": [
    "# Load the video\n",
    "encoded_vid = pytorchvideo.data.encoded_video.EncodedVideo.from_path('theatre.webm')\n",
    "print('Completed loading encoded video.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3edb57ca",
   "metadata": {},
   "source": [
    "## Generate bounding boxes and action predictions for a 10-second clip in the video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "500ebdfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Video predictions are generated at an interval of 1 sec from 90 seconds to 100 seconds in the video.\n",
    "time_stamp_range = range(90,100) # time stamps in video for which clip is sampled. \n",
    "clip_duration = 1.0 # Duration of clip used for each inference step.\n",
    "gif_imgs = []\n",
    "\n",
    "for time_stamp in time_stamp_range:    \n",
    "    print(\"Generating predictions for time stamp: {} sec\".format(time_stamp))\n",
    "    \n",
    "    # Generate clip around the designated time stamps\n",
    "    inp_imgs = encoded_vid.get_clip(\n",
    "        time_stamp - clip_duration/2.0, # start second\n",
    "        time_stamp + clip_duration/2.0  # end second\n",
    "    )\n",
    "    inp_imgs = inp_imgs['video']\n",
    "    \n",
    "    # Generate people bbox predictions using Detectron2's off-the-shelf pre-trained predictor.\n",
    "    # We use the middle image in each clip to generate the bounding boxes.\n",
    "    inp_img = inp_imgs[:,inp_imgs.shape[1]//2,:,:]\n",
    "    inp_img = inp_img.permute(1,2,0)\n",
    "    \n",
    "    # Predicted boxes are of the form List[(x_1, y_1, x_2, y_2)]\n",
    "    predicted_boxes = get_person_bboxes(inp_img, predictor) \n",
    "    if len(predicted_boxes) == 0: \n",
    "        print(\"Skipping clip: no people detected at time stamp: \", time_stamp)\n",
    "        continue\n",
    "        \n",
    "    # Preprocess clip and bounding boxes for video action recognition.\n",
    "    inputs, inp_boxes, _ = ava_inference_transform(inp_imgs, predicted_boxes.numpy())\n",
    "    # Prepend data sample id for each bounding box. \n",
    "    # For more details, refer to RoIAlign in Detectron2.\n",
    "    inp_boxes = torch.cat([torch.zeros(inp_boxes.shape[0],1), inp_boxes], dim=1)\n",
    "    \n",
    "    # Generate actions predictions for the bounding boxes in the clip.\n",
    "    # The model here takes in the pre-processed video clip and the detected bounding boxes.\n",
    "    if isinstance(inputs, list):\n",
    "        inputs = [inp.unsqueeze(0).to(device) for inp in inputs]\n",
    "    else:\n",
    "        inputs = inputs.unsqueeze(0).to(device)\n",
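    "    with torch.no_grad():  # inference only, so skip gradient tracking\n",
    "        preds = video_model(inputs, inp_boxes.to(device))\n",
    "\n",
    "    preds = preds.to('cpu')\n",
    "    # The model is trained on AVA, whose labels are 1-indexed, so prepend a zero column to align column indices with label ids.\n",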
    "    preds = torch.cat([torch.zeros(preds.shape[0],1), preds], dim=1)\n",
    "    \n",
    "    # Plot predictions on the video and save for later visualization.\n",
    "    inp_imgs = inp_imgs.permute(1,2,3,0)\n",
    "    inp_imgs = inp_imgs/255.0\n",
    "    out_img_pred = video_visualizer.draw_clip_range(inp_imgs, preds, predicted_boxes)\n",
    "    gif_imgs += out_img_pred\n",
    "\n",
    "print(\"Finished generating predictions.\")"
   ]
  },
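  {
   "cell_type": "markdown",
   "id": "gr-sketch-prepend-column",
   "metadata": {},
   "source": [
    "The two `torch.cat` calls in the loop above each prepend one column, for different reasons. Here is a minimal sketch with plain Python lists; the helper `prepend_column` and the sample values are hypothetical, and a three-class score row stands in for the model's real output.\n",
    "\n",
    "```python\n",
    "def prepend_column(rows, value=0.0):\n",
    "    # Return a copy of `rows` with `value` inserted as the first column.\n",
    "    return [[value] + list(row) for row in rows]\n",
    "\n",
    "\n",
    "# Boxes as [x1, y1, x2, y2]. The new first column is the index of the clip\n",
    "# in the batch, which is 0 because the loop processes one clip at a time.\n",
    "boxes = [[10.0, 20.0, 110.0, 220.0], [30.0, 40.0, 90.0, 200.0]]\n",
    "roi_boxes = prepend_column(boxes, 0.0)\n",
    "\n",
    "# One score row per box. Column 0 of the raw scores belongs to class id 1.\n",
    "scores = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]\n",
    "shifted = prepend_column(scores, 0.0)\n",
    "\n",
    "# After the shift, the column index equals the 1-indexed class id.\n",
    "best_ids = [max(range(len(row)), key=row.__getitem__) for row in shifted]\n",
    "print(roi_boxes[0])  # [0.0, 10.0, 20.0, 110.0, 220.0]\n",
    "print(best_ids)      # [2, 1]\n",
    "```\n",
    "\n",
    "Without the zero column, the best class for the first box would be read as id 1 instead of id 2, and every label looked up in `label_map` would be off by one."
   ]
  },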
  {
   "cell_type": "markdown",
   "id": "8da24031",
   "metadata": {},
   "source": [
    "## Save predictions as video\n",
    "The generated video consists of bounding boxes with predicted actions for each bounding box."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c4ae73fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "height, width = gif_imgs[0].shape[0], gif_imgs[0].shape[1]\n",
    "\n",
    "video_save_path = 'output_detections.mp4'\n",
    "video = cv2.VideoWriter(video_save_path, cv2.VideoWriter_fourcc(*'DIVX'), 7, (width, height))\n",
    "\n",
    "for image in gif_imgs:\n",
    "    # Frames are float RGB in [0, 1]; OpenCV writes uint8 BGR.\n",
    "    img = (255*image).astype(np.uint8)\n",
    "    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n",
    "    video.write(img)\n",
    "video.release()\n",
    "\n",
    "print('Predictions are saved to the video file: ', video_save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
